import pandas as p
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from AnalyzeHistoricalNHTS import getVehicleDataByYear, getPersonDataByYear, getHouseholdDataByYear, get_state_safety_inspection_df_by_year, get_state_safety_inspection_dict
from tqdm import tqdm

def plotText(y_high=3, y_low=1):
    """Label the two plotted lines directly on the current matplotlib axes.

    The "not subject" label is placed at x=1, y=``y_high`` and the "subject"
    label at x=2, y=``y_low``. Both are left-aligned and coloured with the
    module-level ``noSafetyColor`` / ``safetyColor`` values, so those names
    must be defined before this function is called.

    Args:
        y_high: y data coordinate of the "not subject to inspections" label.
        y_low: y data coordinate of the "subject to inspections" label.
    """
    def _annotate(x_pos, y_pos, message, colour):
        # Single place that fixes the shared text styling (left-aligned).
        plt.text(x_pos, y_pos, message, color=colour, ha='left')

    _annotate(1, y_high, 'Households not subject\nto safety inspections', noSafetyColor)
    _annotate(2, y_low, 'Households subject\nto safety inspections', safetyColor)

if __name__ == '__main__':
    # Script entry point: draw a weighted line plot of household driver count
    # against household size, split by the household's safety-inspection flag,
    # then save the figure to disk and display it.
    year = 2017
    hhData = getHouseholdDataByYear(year, small=True)
    # NOTE(review): 'HasSafety' is read below while the merge that would
    # presumably supply it is disabled; it looks like getHouseholdDataByYear is
    # expected to provide that column already -- confirm before removing this.
    # safety_inspection_df = get_state_safety_inspection_df_by_year(year)
    # hhData = hhData.merge(safety_inspection_df, left_on='HHSTATE', right_on='State Code')

    # Work on an explicit copy: a column is added below, and assigning into a
    # selection of hhData would trigger pandas' SettingWithCopyWarning.
    hhDataSmall = hhData.loc[:, ['HOUSEID', 'HHSIZE', 'DRVRCNT', 'HasSafety', 'WTHHFIN']].copy()

    # Duplicate the flag under a human-readable name used as the plot's hue.
    hhDataSmall['Subject to Safety Inspection'] = hhDataSmall['HasSafety']

    # Figure geometry: width 5 with a golden-ratio aspect (height = width / phi).
    x = 5
    gr = (1 + np.sqrt(5)) / 2
    y = x / gr
    x_lim = [0, 9]

    # These two names are read as globals by plotText(); keep them module-level.
    noSafetyColor = 'xkcd:red'
    safetyColor = 'xkcd:blue'

    # Create exactly one figure (a second, unused figure would otherwise be
    # left open and shown as a blank window by plt.show()).
    fig, ax = plt.subplots(1, 1, figsize=(x, y))

    # Draw onto the axes created above rather than relying on the implicit
    # "current axes" state.
    sns.lineplot(data=hhDataSmall, x='HHSIZE', y='DRVRCNT', hue='Subject to Safety Inspection', weights='WTHHFIN',
                 palette=[noSafetyColor, safetyColor], legend=False, ax=ax)
    ax.set_ylabel('Number of Drivers In Household')
    ax.set_xlabel('Number of People In Household')
    # The legend is disabled above; the lines are labelled in place instead.
    plotText()
    ax.set_xlim(x_lim)
    ax.set_title('Household Drivers vs. Household Size\nBy Safety Inspection Requirement')
    fig.savefig('Plots/HouseholdDriversVsHouseholdSize.png', bbox_inches='tight')
    plt.show()
